# src/pointer/run_pointer_cc.py
# Orchestrator for the Compact-Curvature translator (CC)

import argparse
import itertools
import json
import math
import os
import pathlib
from typing import Any, Dict, Iterator, List, Optional, Tuple
import numpy as np
import yaml

from .translator_cc import CCConfig, cc_translate_and_fit, _npz_load_e0

def _ensure_dir(p: str):
    pathlib.Path(p).mkdir(parents=True, exist_ok=True)

def _load_yaml(path: str) -> Dict[str, Any]:
    """Load a YAML config file and return its top-level mapping.

    Parsing uses ``yaml.safe_load``, so no arbitrary Python objects are built.

    Args:
        path: Path to the YAML file.

    Returns:
        The parsed mapping. An empty file yields ``{}``; ``safe_load`` returns
        ``None`` for an empty document, and callers expect a dict they can
        ``.get`` from.

    Raises:
        ValueError: If the document's top level is not a mapping (e.g. a list
            or a bare scalar).
    """
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Config {path} must contain a YAML mapping at top level, "
            f"got {type(data).__name__}"
        )
    return data

def _cfg_to_cc(cfg: Dict[str, Any]) -> CCConfig:
    """Build a ``CCConfig`` from the parsed runner config.

    Reads the top-level ``translator`` key, which defaults to "CC" and must
    equal "CC" case-insensitively. Also reads the optional ``cc`` mapping.
    Every field missing from ``cc`` falls back to the default written
    inline below. Numeric and boolean fields are coerced with
    ``int``/``float``/``bool``. List fields are converted to tuples.

    Raises:
        ValueError: If ``translator`` is not the string "CC", or if ``cc`` is
            present but is not a mapping.
    """
    t = cfg.get("translator", "CC")
    # Guard non-strings (e.g. YAML null). Otherwise ``.upper()`` raises
    # AttributeError instead of the intended ValueError.
    if not isinstance(t, str) or t.upper() != "CC":
        raise ValueError(f"pointer_cc runner expects translator='CC' (got {t!r})")

    # A bare ``cc:`` key with no body parses to None; treat it as absent.
    cc = cfg.get("cc")
    if cc is None:
        cc = {}
    elif not isinstance(cc, dict):
        raise ValueError(f"'cc' section must be a mapping, got {type(cc).__name__}")

    return CCConfig(
        operator         = cc.get("operator", "LoG"),
        sigma_list       = tuple(cc.get("sigma_list", [1,2,3])),
        normalize        = cc.get("normalize", "zscore"),
        threshold        = cc.get("threshold", "quantile:0.90"),
        connectivity     = int(cc.get("connectivity", 8)),
        morph_open       = int(cc.get("morph_open", 1)),
        morph_close      = int(cc.get("morph_close", 1)),
        fill_holes       = bool(cc.get("fill_holes", True)),
        remove_small_px  = int(cc.get("remove_small_px", 50)),
        keep             = cc.get("keep", "largest"),
        sheet3D          = bool(cc.get("sheet3D", True)),
        # String expressions such as "0.5*sigma" are passed through
        # unevaluated; CCConfig is presumably responsible for interpreting them.
        epsilon_soften   = cc.get("epsilon_soften", "0.5*sigma"),
        padding_factor   = int(cc.get("padding_factor", 0)),
        lambda_sweep     = tuple(cc.get("lambda_sweep", [0.2,0.5,1.0])),
        radial_bins_scheme = cc.get("radial_bins_scheme", "log"),
        radial_bins      = int(cc.get("radial_bins", 48)),
        fit_window_min   = cc.get("fit_window_min", "3*sigma"),
        fit_window_max_fracL = float(cc.get("fit_window_max_fracL", 0.30)),
        lensing_b_min    = int(cc.get("lensing_b_min", 8)),
        lensing_b_max    = int(cc.get("lensing_b_max", 90)),
        lensing_b_n      = int(cc.get("lensing_b_n", 32)),
        regression_weights = cc.get("regression_weights", "counts"),
    )

def _nested_path(data_dir: str, gauge: str, L: int, b: float, kappa: float, f: float, seed: int) -> str:
    return os.path.join(
        data_dir,
        gauge,
        f"L{L}",
        f"b{b}",
        f"k{float(kappa):.2f}",
        f"f{float(f):.2f}",
        f"seed{int(seed)}",
        "E0.npz"
    )

def _find_e0(data_dir: str, gauge: str, L: int, b: float, kappa: float, f: float, seed: int) -> str:
    """
    Locate the E0 snapshot for one anchor.

    The nested layout (see ``_nested_path``) is checked first. If it is
    absent, the flat fallback name is tried:
    E0_<gauge>_L<L>_b<b>_k<kappa>_f<f>_seed<seed>.npz
    The first candidate that exists is returned; otherwise FileNotFoundError.
    """
    nested = _nested_path(data_dir, gauge, L, b, kappa, f, seed)
    flat_name = f"E0_{gauge}_L{L}_b{b}_k{float(kappa):.2f}_f{float(f):.2f}_seed{int(seed)}.npz"
    flat = os.path.join(data_dir, flat_name)
    for candidate in (nested, flat):
        if os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(f"Snapshot not found: {nested} (or {flat})")

def _as_list(value: Any) -> List[Any]:
    """Return *value* as a list: None -> [], list/tuple -> list, scalar -> [scalar]."""
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    # Wrap scalars (including strings, which would otherwise be iterated per character).
    return [value]

def _cfg_list(cfg: Dict[str, Any], key: str, default: List[Any]) -> List[Any]:
    """Read a list-valued sweep key; a missing or null key falls back to *default*."""
    value = cfg.get(key)
    return _as_list(default if value is None else value)

def _iter_anchors(cfg: Dict[str, Any]) -> Iterator[Tuple[str, int, float, float, float, int]]:
    """Yield normalized (gauge, L, b, kappa, f, seed) anchors in sweep order.

    The order matches nested loops gauge > L > b > kappa > f > seed (seed varies
    fastest). ``kappa`` may be a per-gauge mapping, whose missing gauges are
    skipped, or a single list/scalar shared by all gauges.
    """
    gauges = _cfg_list(cfg, "gauges", ["SU3", "SU2"])
    Ls = [int(L) for L in _cfg_list(cfg, "L", [256])]
    b_list = [float(b) for b in _cfg_list(cfg, "b", [3.5])]
    f_list = [float(f) for f in _cfg_list(cfg, "f", [0.00, 0.10, 0.30])]
    seeds = [int(s) for s in _cfg_list(cfg, "seeds", [0, 1, 2, 3, 4])]

    kappa_cfg = cfg.get("kappa")
    if kappa_cfg is None:
        kappa_cfg = {"SU3": [1.00, 0.75], "SU2": [1.00]}

    for gauge in gauges:
        raw_kappas = kappa_cfg.get(gauge) if isinstance(kappa_cfg, dict) else kappa_cfg
        kappas = [float(k) for k in _as_list(raw_kappas)]
        for L, b, kappa, f, seed in itertools.product(Ls, b_list, kappas, f_list, seeds):
            yield gauge, L, b, kappa, f, seed

def _json_default(obj: Any) -> Any:
    """``json`` fallback: convert NumPy scalars/arrays to Python types, reject anything else."""
    # NOTE(review): the exact types in cc_translate_and_fit's result are not
    # visible here; this only covers NumPy values that json cannot encode natively.
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

def run_pointer_cc(config_path: str, out_dir: str):
    """Run the CC translator over every sweep anchor and write one JSON per run.

    Sweep keys read from the config:
      - ``gauges``, ``L``, ``b``, ``f``, ``seeds``: lists or single scalars.
      - ``kappa``: a per-gauge mapping, or a list/scalar shared by all gauges.
      - ``data_dir``: where snapshots are looked up.

    Each anchor's snapshot is loaded as float64 and checked to be ``(L, L)``.
    It is passed to ``cc_translate_and_fit``. The result is written to
    ``<out_dir>/<gauge>_L<L>_b<b>_k<kappa:.2f>_f<f:.2f>_seed<seed>.json`` using
    the same normalized values as the input lookup.

    Args:
        config_path: Path to the YAML sweep config.
        out_dir: Output directory; created if missing.

    Raises:
        FileNotFoundError: If a snapshot for some anchor is missing.
        ValueError: If a snapshot does not have shape ``(L, L)``, or the
            translator config is invalid.
    """
    cfg = _load_yaml(config_path)

    data_dir = cfg.get("data_dir")
    if data_dir is None:
        data_dir = "data/inputs"
    _ensure_dir(out_dir)

    cc_cfg = _cfg_to_cc(cfg)
    # Loop-invariant provenance subset of the translator config; build once.
    cc_provenance = {
        "operator": cc_cfg.operator,
        "sigma_list": list(cc_cfg.sigma_list),
        "threshold": cc_cfg.threshold,
        "connectivity": cc_cfg.connectivity,
        "epsilon_soften": cc_cfg.epsilon_soften,
        "lambda_sweep": list(cc_cfg.lambda_sweep),
        "radial_bins": cc_cfg.radial_bins,
        "fit_window_min": cc_cfg.fit_window_min,
        "fit_window_max_fracL": cc_cfg.fit_window_max_fracL,
    }

    runs = 0
    for gauge, L, b, kappa, f, seed in _iter_anchors(cfg):
        e0_path = _find_e0(data_dir, gauge, L, b, kappa, f, seed)
        E0 = _npz_load_e0(e0_path).astype(np.float64)
        if E0.shape != (L, L):
            raise ValueError(f"E0 shape mismatch at {e0_path}: got {E0.shape}, expected {(L,L)}")

        res = cc_translate_and_fit(E0, cc_cfg, L)

        out = {
            "sim": "POINTER_CC",
            "gauge": gauge, "L": L, "b": b, "kappa": kappa,
            "f": f, "seed": seed,
            "translator": "CC", "cc_config": cc_provenance,
            "metrics": res,
        }
        fname = f"{gauge}_L{L}_b{b}_k{kappa:.2f}_f{f:.2f}_seed{seed}.json"
        # Serialize before opening the file so an encoding error cannot leave
        # a truncated JSON file behind.
        text = json.dumps(out, indent=2, default=_json_default)
        with open(os.path.join(out_dir, fname), "w", encoding="utf-8") as fh:
            fh.write(text)
        runs += 1
    print(f"[pointer_cc] wrote {runs} JSON results to {out_dir}")

def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point: parse arguments and run the CC sweep.

    Args:
        argv: Arguments to parse, excluding the program name. ``None``
            (the default) lets argparse read ``sys.argv[1:]``, so script use
            is unchanged. A list allows programmatic invocation.
    """
    ap = argparse.ArgumentParser(
        description="Run the Compact-Curvature (CC) pointer translator over a config sweep."
    )
    ap.add_argument("--config", required=True, help="Path to configs/pointer_cc.yaml")
    ap.add_argument("--output", required=False, default="runs_cc", help="Output directory for JSONs")
    args = ap.parse_args(argv)
    run_pointer_cc(args.config, args.output)

# Script entry point. Because of the relative import of .translator_cc, the
# module must be run as part of its package (python -m <package>.run_pointer_cc),
# not as a plain file path.
if __name__ == "__main__":
    main()
